{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import os\n",
    "import warnings\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from fastai.vision import *\n",
    "from torchsummary import summary\n",
    "\n",
    "# from models.custom_resnet import _resnet, Bottleneck\n",
    "from models.resnet_cifar import *\n",
    "from utils import _get_accuracy\n",
    "# Fix the RNG seed so repeated runs start from the same random state.\n",
    "torch.manual_seed(1)\n",
    "if torch.cuda.is_available():\n",
    "    # Only touch the CUDA runtime when a GPU exists: set_device(0) fails on a\n",
    "    # CPU-only machine, and the later cells already fall back to the CPU.\n",
    "    torch.cuda.set_device(0)\n",
    "    torch.cuda.manual_seed(1)\n",
    "\n",
    "# Run configuration read by the cells below.\n",
    "# `stage` must lie in 0..6 (5 = classifier stage, 6 = finetuning stage).\n",
    "hyper_params = dict(\n",
    "    dataset='cifar10',\n",
    "    stage=6,\n",
    "    repeated=0,\n",
    "    num_classes=10,\n",
    "    batch_size=64,\n",
    "    num_epochs=2,\n",
    "    learning_rate=1e-6,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Local folder of the CIFAR archive. Presumably untar_data downloads and\n",
    "# extracts it on first use and reuses the cached copy afterwards - confirm\n",
    "# against the installed fastai version.\n",
    "path = untar_data(url=URLs.CIFAR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fastai transform set with flipping switched off.\n",
    "tfms = get_transforms(do_flip=False)\n",
    "# Build the data bunch from the on-disk folders: `train` for training and\n",
    "# `test` as the validation split, then normalize with cifar_stats.\n",
    "data = ImageDataBunch.from_folder(\n",
    "    path,\n",
    "    train='train',\n",
    "    valid='test',\n",
    "    bs=hyper_params[\"batch_size\"],\n",
    "    size=32,\n",
    "    ds_tfms=tfms,\n",
    ").normalize(cifar_stats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint to start from: ../saved_models/<dataset>/medium_classifier/model0.pt\n",
    "filename = ''.join(['../saved_models/', hyper_params['dataset'], '/medium_classifier/model0.pt'])\n",
    "\n",
    "# Build the network on the CPU and load the weights there first\n",
    "# (map_location='cpu'), so the load itself does not need a GPU.\n",
    "net = resnet14_cifar().cpu()\n",
    "net.load_state_dict(torch.load(filename, map_location='cpu'))\n",
    "\n",
    "# Move the model to the GPU when one is available.\n",
    "if torch.cuda.is_available():\n",
    "    net = net.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name '_get_accuracy' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-5-ce9557d3dcf2>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     50\u001b[0m     \u001b[0mval_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mval\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mval\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     51\u001b[0m     \u001b[0mval_loss_list\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mval_loss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 52\u001b[0;31m     \u001b[0mval_acc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_get_accuracy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalid_dl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnet\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     53\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     54\u001b[0m     \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'epoch : '\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m' / '\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhyper_params\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"num_epochs\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m' | TL : '\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mround\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_loss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m6\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m' | VL : '\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mround\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mval_loss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m6\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m' | VA : '\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mround\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mval_acc\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;36m100\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m6\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mNameError\u001b[0m: name '_get_accuracy' is not defined"
     ]
    }
   ],
   "source": [
    "# Fine-tune `net` with Adam. After every epoch: record the per-sample mean\n",
    "# train/validation loss, score the validation set with _get_accuracy, and\n",
    "# checkpoint the weights whenever that score beats the best one so far.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr = hyper_params['learning_rate'])\n",
    "train_loss_list = list()\n",
    "val_loss_list = list()\n",
    "best_val_acc = 0  # best validation accuracy seen so far, in percent\n",
    "savename = '../saved_models/' + str(hyper_params['dataset']) + '/resnet14_classifier/model0.pt'\n",
    "# torch.save will not create missing folders; make sure the target exists now\n",
    "# instead of failing after a whole epoch of training.\n",
    "os.makedirs(os.path.dirname(savename), exist_ok=True)\n",
    "\n",
    "for epoch in range(hyper_params[\"num_epochs\"]):\n",
    "    # ---- training pass ----\n",
    "    net.train()\n",
    "    loss_sum, n_seen = 0.0, 0\n",
    "    for images, labels in data.train_dl:\n",
    "        # Plain tensors are enough here: autograd.Variable is a deprecated no-op wrapper.\n",
    "        images = images.to(device).float()\n",
    "        labels = labels.to(device)\n",
    "\n",
    "        y_pred = net(images)\n",
    "        loss = F.cross_entropy(y_pred, labels)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # cross_entropy returns the batch mean; weight it by the batch size so the\n",
    "        # epoch figure is a true per-sample mean even if the last batch is short.\n",
    "        loss_sum += loss.item() * labels.size(0)\n",
    "        n_seen += labels.size(0)\n",
    "\n",
    "    train_loss = loss_sum / n_seen\n",
    "    train_loss_list.append(train_loss)\n",
    "\n",
    "    # ---- validation pass (no gradients) ----\n",
    "    net.eval()\n",
    "    loss_sum, n_seen = 0.0, 0\n",
    "    with torch.no_grad():\n",
    "        for images, labels in data.valid_dl:\n",
    "            images = images.to(device).float()\n",
    "            labels = labels.to(device)\n",
    "\n",
    "            y_pred = net(images)\n",
    "            loss = F.cross_entropy(y_pred, labels)\n",
    "\n",
    "            loss_sum += loss.item() * labels.size(0)\n",
    "            n_seen += labels.size(0)\n",
    "\n",
    "    val_loss = loss_sum / n_seen\n",
    "    val_loss_list.append(val_loss)\n",
    "    # Presumably a fraction in [0, 1] (it is scaled by 100 below) - confirm in utils.\n",
    "    val_acc = _get_accuracy(data.valid_dl, net)\n",
    "\n",
    "    print('epoch : ', epoch + 1, ' / ', hyper_params[\"num_epochs\"], ' | TL : ', round(train_loss, 6), ' | VL : ', round(val_loss, 6), ' | VA : ', round(val_acc * 100, 6))\n",
    "\n",
    "    # Keep only the best-scoring weights on disk.\n",
    "    if (val_acc * 100) > best_val_acc:\n",
    "        print('saving model')\n",
    "        best_val_acc = val_acc * 100\n",
    "        torch.save(net.state_dict(), savename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (ak_fastai)",
   "language": "python",
   "name": "ak_fastai"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
